# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Constellation capsule autoencoder implementation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from monty.collections import AttrDict
import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import nest
import tensorflow_probability as tfp

from stacked_capsule_autoencoders.capsules import capsule as _capsule
from stacked_capsule_autoencoders.capsules import math_ops
from stacked_capsule_autoencoders.capsules import neural
from stacked_capsule_autoencoders.capsules import plot
from stacked_capsule_autoencoders.capsules import tensor_ops
from stacked_capsule_autoencoders.capsules.eval import bipartite_match
from stacked_capsule_autoencoders.capsules.models.model import Model

tfd = tfp.distributions


class ConstellationCapsule(snt.AbstractModule):
  """Capsule decoder for constellations.

  Uses a `_capsule.CapsuleLayer` to turn an encoding into per-capsule votes,
  scales and presences, scores the input points with an
  `_capsule.OrderInvariantCapsuleLikelihood`, and attaches auxiliary terms
  (`mixing_kl`, `sparsity_loss`, `caps_presence_prob`, `mean_scale`) to the
  returned result.
  """

  def __init__(self, n_caps, n_caps_dims, n_votes, **capsule_kwargs):
    """Builds the module.

    Args:
      n_caps: int, number of capsules.
      n_caps_dims: int, number of capsule coordinates.
      n_votes: int, number of votes generated by each capsule.
      **capsule_kwargs: kwargs passed to capsule layer.
    """
    super(ConstellationCapsule, self).__init__()
    self._n_caps = n_caps
    self._n_caps_dims = n_caps_dims
    self._n_votes = n_votes
    self._capsule_kwargs = capsule_kwargs

  def _build(self, h, x, presence=None):
    """Decodes `h` and scores the points `x` under the capsule mixture.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims]; its first two
        dimensions must be statically known, since they are read with
        `as_list()` and used in a Python division and a `tf.reshape`.
      presence: Tensor or None; if it exists, it indicates which input points
        exist. NOTE(review): it is multiplied, after `expand_dims(presence,
        -1)`, with a one-hot over capsules that is then summed over axis 1,
        which appears to require shape [B, n_points] (not [B, n_points, 1]) --
        confirm with callers.

    Returns:
      The capsule layer's result, updated in place with `transform` (the raw
      capsule-layer votes), the likelihood outputs (`ll_res._asdict()`), and
      `mixing_kl`, `sparsity_loss`, `caps_presence_prob` and `mean_scale`.
    """
    batch_size, n_input_points, _ = x.shape.as_list()

    capsule = _capsule.CapsuleLayer(self._n_caps, self._n_caps_dims,
                                    self._n_votes, **self._capsule_kwargs)

    res = capsule(h)
    # Keep the raw capsule-layer votes under `transform`; `vote` is replaced by
    # the output of `math_ops.apply_transform` applied to them.
    res.transform = res.vote
    res.vote = math_ops.apply_transform(transform=res.vote)
    # Merge dims 1 and 2 (presumably capsules x votes) of every non-scalar
    # entry so the likelihood sees one flat axis of votes.
    for k, v in res.items():
      if v.shape.ndims > 0:
        res[k] = snt.MergeDims(1, 2)(v)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(self._n_votes,
                                                          res.vote, res.scale,
                                                          res.vote_presence)
    ll_res = likelihood(x, presence)
    res.update(ll_res._asdict())

    # post processing
    # KL-style term: the axis-1 softmax of `mixing_logits` against a uniform
    # prior with probability 1 / n_input_points, summed over the last axis and
    # averaged over the remaining ones.
    mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
    prior_mixing_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll_res.mixing_log_prob - prior_mixing_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

    # Count, per capsule, how many (present) points it was assigned. Indices
    # outside [0, n_caps) produce all-zero one-hot rows.
    wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)

    if presence is not None:
      wins_per_caps *= tf.expand_dims(presence, -1)

    wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

    # Only capsules that won at least one point contribute to the sparsity
    # loss; their presence logit is pushed towards 1 if they won more than one
    # point and towards 0 otherwise.
    has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
    should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))

    sparsity_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=res.pres_logit_per_caps)

    sparsity_loss = tf.reduce_sum(sparsity_loss * has_any_wins, -1)
    sparsity_loss = tf.reduce_mean(sparsity_loss)

    # Per-capsule presence probability: the maximum over that capsule's votes.
    caps_presence_prob = tf.reduce_max(
        tf.reshape(res.vote_presence,
                   [batch_size, self._n_caps, self._n_votes]), 2)

    res.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(res.scale)
    ))
    return res


class ConstellationAutoencoder(Model):
  """Capsule autoencoder for constellations of points.

  Encodes the input points with `encoder`, decodes them with `decoder` and
  trains on the decoder's negative log-likelihood plus optional sparsity and
  regularization terms (see `_loss`).
  """

  def __init__(self,
               encoder,
               decoder,
               input_key='corners',
               presence_key='presence',
               n_classes=2,
               dynamic_l2_weight=0.,
               mixing_kl_weight=0.,
               sparsity_weight=0.,
               prior_sparsity_loss_type='kl',
               prior_within_example_sparsity_weight=0.,
               prior_between_example_sparsity_weight=0.,
               prior_within_example_constant=0.,
               posterior_sparsity_loss_type='kl',
               posterior_within_example_sparsity_weight=0.,
               posterior_between_example_sparsity_weight=0.):
    """Builds the model.

    Args:
      encoder: callable applied to the flattened input points (followed by
        the presence mask, if any); returns the encoding `h`.
      decoder: module called as `decoder(h, *inputs)`. It must expose
        `_n_caps`, and its result must provide the fields read by this class
        (`log_prob`, `posterior_mixing_probs`, `caps_presence_prob`,
        `sparsity_loss`, `dynamic_weights_l2`, `is_from_capsule`, `winner`).
      input_key: key of the input points in `data`.
      presence_key: key of the point-presence mask in `data`; a falsy value
        disables the mask.
      n_classes: `num_classes` passed to `_capsule.sparsity_loss`.
      dynamic_l2_weight: weight of `res.dynamic_weights_l2` in the loss.
      mixing_kl_weight: weight for `res.mixing_kl`; stored but currently not
        used by `_loss`.
      sparsity_weight: weight of the decoder's `res.sparsity_loss`.
      prior_sparsity_loss_type: loss type of the sparsity loss computed on
        `res.caps_presence_prob`.
      prior_within_example_sparsity_weight: weight (added) of the prior
        within-example sparsity loss.
      prior_between_example_sparsity_weight: weight (subtracted) of the prior
        between-example sparsity loss.
      prior_within_example_constant: `within_example_constant` passed to the
        prior sparsity loss.
      posterior_sparsity_loss_type: loss type of the sparsity loss computed on
        `posterior_mixing_probs` summed over points and divided by the number
        of points.
      posterior_within_example_sparsity_weight: weight (added) of the
        posterior within-example sparsity loss.
      posterior_between_example_sparsity_weight: weight (subtracted) of the
        posterior between-example sparsity loss.
    """
    super(ConstellationAutoencoder, self).__init__()
    self._encoder = encoder
    self._decoder = decoder
    self._input_key = input_key
    self._presence_key = presence_key
    self._n_classes = n_classes
    self._mixing_kl_weight = mixing_kl_weight
    self._sparsity_weight = sparsity_weight
    self._dynamic_l2_weight = dynamic_l2_weight

    self._prior_sparsity_loss_type = prior_sparsity_loss_type
    self._prior_within_example_sparsity_weight = prior_within_example_sparsity_weight
    self._prior_between_example_sparsity_weight = prior_between_example_sparsity_weight
    self._prior_within_example_constant = prior_within_example_constant
    self._posterior_sparsity_loss_type = posterior_sparsity_loss_type
    self._posterior_within_example_sparsity_weight = posterior_within_example_sparsity_weight
    self._posterior_between_example_sparsity_weight = posterior_between_example_sparsity_weight

  def _presence_from_data(self, data):
    """Returns the presence mask stored in `data`, or None if there is none."""
    if not self._presence_key:
      return None
    try:
      return data[self._presence_key]
    except KeyError:
      return None

  def _build(self, data):
    """Encodes and decodes `data` and attaches sparsity losses to the result."""
    x = data[self._input_key]
    presence = data[self._presence_key] if self._presence_key else None

    # Encoder and decoder receive the flattened inputs, with the presence mask
    # (if any) appended as the last positional argument.
    inputs = nest.flatten(x)
    if presence is not None:
      inputs.append(presence)

    h = self._encoder(*inputs)
    res = self._decoder(h, *inputs)

    n_points = int(res.posterior_mixing_probs.shape[1])
    mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)

    (res.posterior_within_sparsity_loss,
     res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._posterior_sparsity_loss_type,
         mass_explained_by_capsule / n_points,
         num_classes=self._n_classes)

    (res.prior_within_sparsity_loss,
     res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._prior_sparsity_loss_type,
         res.caps_presence_prob,
         num_classes=self._n_classes,
         within_example_constant=self._prior_within_example_constant)

    return res

  def _loss(self, data, res):
    """Returns the training loss.

    The negative log-likelihood plus weighted regularizers. Within-example
    sparsity losses are added and between-example ones are subtracted, so with
    positive weights the latter are maximized.
    """
    # NOTE: `res.mixing_kl` is not part of the loss, so `mixing_kl_weight`
    # currently has no effect.
    loss = (
        -res.log_prob
        + self._sparsity_weight * res.sparsity_loss
        + self._dynamic_l2_weight * res.dynamic_weights_l2
        + (self._posterior_within_example_sparsity_weight
           * res.posterior_within_sparsity_loss)
        - (self._posterior_between_example_sparsity_weight
           * res.posterior_between_sparsity_loss)
        + (self._prior_within_example_sparsity_weight
           * res.prior_within_sparsity_loss)
        - (self._prior_between_example_sparsity_weight
           * res.prior_between_sparsity_loss)
    )

    return loss

  def _report(self, data, res):
    """Adds per-capsule vote counts and segmentation accuracy to the reports."""
    reports = super(ConstellationAutoencoder, self)._report(data, res)

    n_caps = self._decoder._n_caps  # pylint:disable=protected-access
    is_from_capsule = res.is_from_capsule

    presence = self._presence_from_data(data)
    if presence is None:
      # Without a mask every point counts as present; `tf.cast(None, ...)`
      # would otherwise fail below.
      presence = tf.ones_like(is_from_capsule, dtype=tf.float32)

    pres = tf.cast(presence, is_from_capsule.dtype)
    # Shift capsule indices by one so that absent points map to index 0, which
    # is then dropped. Indices outside the one-hot depth give all-zero rows.
    capsule_one_hot = tf.one_hot(
        (is_from_capsule + pres) * pres, depth=n_caps + 1)[Ellipsis, 1:]

    num_per_group = tf.reduce_sum(capsule_one_hot, 1)
    num_per_group_per_batch = tf.reduce_mean(tf.to_float(num_per_group), 0)

    reports.update({
        'votes_per_capsule_{}'.format(k): v
        for k, v in enumerate(tf.unstack(num_per_group_per_batch))
    })

    reports.segm_acc = tensor_ops.py_func_metric(
        eval_segmentation, [is_from_capsule, data.pattern_id, presence])
    return reports

  def _plot(self, data, res, name=None):
    """Renders up to 32 constellations with their capsule assignments.

    Args:
      data: input data dict.
      res: result of `_build`.
      name: unused; kept for interface compatibility.

    Returns:
      Tuple `(plot_dict, plot_params)`.
    """
    del name  # Unused.
    presence = self._presence_from_data(data)
    n_caps = self._decoder._n_caps  # pylint:disable=protected-access

    rendered = plot.render_constellations(
        res.winner,
        res.is_from_capsule,
        gt_points=data[self._input_key],
        canvas_size=(64, 64),
        n_caps=n_caps,
        gt_presence=presence,
        # NOTE(review): the ground-truth mask is reused as `pred_presence`;
        # confirm this is intended.
        pred_presence=presence,
        caps_presence_prob=res.caps_presence_prob,
    )

    plot_params = dict(imgs=dict(zoom=3.))
    plot_dict = dict(imgs=rendered[:32])
    return plot_dict, plot_params


# def bipartite_match(pred, gt, presence):
#   """Performs bipartite likelihood matching between 'pred' and 'gt' scores."""
#
#   n_gt_labels = np.unique(gt).shape[0]
#   n_pred_labels = np.unique(pred).shape[0]
#
#   cost_matrix = np.zeros([n_gt_labels, n_pred_labels], dtype=np.int32)
#   for label in range(n_gt_labels):
#     label_idx = (gt == label)
#     for new_label in range(n_pred_labels):
#       errors = np.equal(pred[label_idx], new_label) * presence[label_idx]
#       num_errors = -errors.sum()
#       cost_matrix[label, new_label] = num_errors
#
#   row_idx, col_idx = linear_sum_assignment(cost_matrix)
#   num_correct = -cost_matrix[row_idx, col_idx].sum()
#   return num_correct


def eval_segmentation(pred, gt, presence=None):
  """Returns the fraction of present points that are correctly segmented.

  Each example's predicted labels are matched to its ground-truth labels with
  `bipartite_match`; the per-example `num_correct` counts are summed and
  divided by the total presence mass.

  Args:
    pred: array of predicted labels; axis 0 indexes examples.
    gt: array of ground-truth labels, indexed like `pred`.
    presence: optional mask indexed like `gt`; defaults to all ones.

  Returns:
    np.float32, total `num_correct` divided by `presence.sum()`.
  """
  mask = presence if presence is not None else np.ones_like(gt)

  total_correct = sum(
      bipartite_match(pred[idx], gt[idx], presence=mask[idx]).num_correct
      for idx in range(pred.shape[0]))

  return np.float32(float(total_correct) / mask.sum())


class ConstellationDecoder(snt.AbstractModule):
  """Capsule decoder for constellations.

  A `neural.BatchMLP` maps the encoding to, per capsule, a presence logit and,
  per vote, a presence logit, a scale and a 2-D position. The input points are
  then scored with an `_capsule.OrderInvariantCapsuleLikelihood`.
  """

  _n_caps_dims = 2

  def __init__(self, n_caps, n_votes, n_hiddens):
    """Builds the module.

    Args:
      n_caps: int, number of capsules.
      n_votes: int, number of votes generated by each capsule.
      n_hiddens: int, number of hidden units.
    """
    super(ConstellationDecoder, self).__init__()
    self._n_caps = n_caps
    self._n_votes = n_votes
    self._n_hiddens = n_hiddens

  def _build(self, h, x, presence=None):
    """Decodes `h` into votes and scores the points `x` under them.

    Args:
      h: Tensor of encodings. NOTE(review): the MLP output is reshaped to
        [B, n_caps, ...] below, which implies `h` carries an n_caps axis
        (presumably [B, n_caps, n_enc_dims]) -- confirm with callers.
      x: Tensor of inputs of shape [B, n_points, n_input_dims]; its first two
        dimensions must be statically known (read with `as_list()`).
      presence: Tensor or None; if it exists, it indicates which input points
        exist. NOTE(review): `expand_dims(presence, -1)` is multiplied with a
        one-hot over capsules, which appears to require shape [B, n_points]
        rather than [B, n_points, 1] -- confirm.

    Returns:
      AttrDict with `dynamic_weights_l2` (always 0), `pres_logit_per_caps`,
      `pres_logit_per_vote`, `scale`, `vote_presence`, `vote`, the likelihood
      outputs (`ll_res._asdict()`), `mixing_kl`, `sparsity_loss`,
      `caps_presence_prob` and `mean_scale`.
    """
    batch_size, n_input_points, _ = x.shape.as_list()
    res = AttrDict(
        dynamic_weights_l2=tf.constant(0.)
    )

    output_shapes = (
        [1],  # per-capsule presence
        [self._n_votes],  # per-vote-presence
        [self._n_votes],  # per-vote scale
        [self._n_votes, self._n_caps_dims]  # per-vote position
    )

    splits = [np.prod(i).astype(np.int32) for i in output_shapes]
    n_outputs = sum(splits)
    batch_mlp = neural.BatchMLP([self._n_hiddens, self._n_hiddens, n_outputs],
                                use_bias=True)

    all_params = batch_mlp(h)
    all_params = tf.split(all_params, splits, -1)
    batch_shape = [batch_size, self._n_caps]
    all_params = [tf.reshape(i, batch_shape + s)
                  for (i, s) in zip(all_params, output_shapes)]

    def add_noise(tensor):
      # Perturbs `tensor` with uniform noise in [-2, 2). The MLP output must be
      # kept (noise is added, not substituted) so the presence logits still
      # depend on `h` and receive gradients.
      return tensor + tf.random.uniform(
          tensor.shape, minval=-.5, maxval=.5) * 4.

    res.pres_logit_per_caps = add_noise(all_params[0])
    res.pres_logit_per_vote = add_noise(all_params[1])
    res.scale = tf.nn.softplus(all_params[2] + .5) + 1e-6
    res.vote_presence = (tf.nn.sigmoid(res.pres_logit_per_caps)
                         * tf.nn.sigmoid(res.pres_logit_per_vote))
    res.vote = all_params[3]

    # Merge the capsule and vote axes (dims 1 and 2) of every non-scalar entry.
    for k, v in res.items():
      if v.shape.ndims > 0:
        res[k] = snt.MergeDims(1, 2)(v)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(self._n_votes,
                                                          res.vote, res.scale,
                                                          res.vote_presence)
    ll_res = likelihood(x, presence)
    res.update(ll_res._asdict())

    # post processing
    # KL-style term against a uniform prior with probability 1 / n_points.
    mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
    prior_mixing_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll_res.mixing_log_prob - prior_mixing_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

    # Count, per capsule, how many (present) points it was assigned.
    wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)

    if presence is not None:
      wins_per_caps *= tf.expand_dims(presence, -1)

    wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

    # Capsules with at least one win contribute; their presence logit is
    # pushed towards 1 if they won more than one point, towards 0 otherwise.
    has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
    should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))

    sparsity_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=res.pres_logit_per_caps)

    sparsity_loss = tf.reduce_sum(sparsity_loss * has_any_wins, -1)
    sparsity_loss = tf.reduce_mean(sparsity_loss)

    # Per-capsule presence probability: the maximum over that capsule's votes.
    caps_presence_prob = tf.reduce_max(
        tf.reshape(res.vote_presence,
                   [batch_size, self._n_caps, self._n_votes]), 2)

    res.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(res.scale)
    ))
    return res
